package com.ming.common.dynamic.spring.dynamic;

import com.sun.tools.javac.file.JavacFileManager;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.Log;
import org.springframework.boot.loader.jar.JarFile;

import javax.tools.*;
import java.io.*;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.stream.Collectors;

/**
 * Java file manager used to expose the dependencies of a repackaged Spring Boot jar to the compiler.
 *
 * <p>When {@code MemoryClassLoader.isJar()} reports a jar launch, the class files of the running jar
 * (with the {@code BOOT-INF/classes/} prefix removed) and of every nested {@code .jar} entry are scanned
 * once and cached per package, so that {@link #list} can serve them for {@code CLASS_PATH}. Classes
 * produced by a compilation are kept in memory (see {@link #getClassBytes()}) instead of being written
 * to disk.
 *
 * @author moon
 * @date 2023-08-10 9:58
 * @since 1.8
 */
public class MemoryJavaFileManager extends ForwardingJavaFileManager<JavaFileManager> {

    /**
     * File-name suffix of compiled class entries.
     */
    private static final String CLASS_SUFFIX = ".class";

    /**
     * Prefix under which Spring Boot repackaging places the application's own classes.
     */
    private static final String BOOT_INF_CLASSES = "BOOT-INF/classes/";

    /**
     * Bytecode written by compilations through this manager, keyed by class name.
     */
    final Map<String, byte[]> classBytesMap = new ConcurrentHashMap<>();

    /**
     * In-memory compiled class objects, keyed by package name ("" for the unnamed package).
     */
    final Map<String, List<JavaFileObject>> classObjectPackageMap = new ConcurrentHashMap<>();

    /**
     * The delegate file manager created in the constructor, viewed as a {@link JavacFileManager}.
     */
    private final JavacFileManager javaFileManager;

    /**
     * Class file objects found in the running jar, keyed by package name; shared by all instances.
     */
    public final static Map<String, List<JavaFileObject>> CLASS_OBJECT_PACKAGE_MAP = new ConcurrentHashMap<>();

    /**
     * Guards the one-time jar scan triggered from {@link #getLibJarsOptions(String)}.
     */
    private static final Object lock = new Object();

    /**
     * Whether the jar scan has run. It is set even when the scan failed, so a failed scan is not retried.
     */
    private static boolean isInit = false;

    /**
     * Scans the running jar and fills {@link #CLASS_OBJECT_PACKAGE_MAP}.
     *
     * <p>Two passes are made:
     * <ol>
     *   <li>the jar's own entries, with the {@code BOOT-INF/classes/} prefix removed;</li>
     *   <li>the entries of every nested {@code .jar}, except {@code tools.jar}.</li>
     * </ol>
     * This is best-effort. Any exception is printed and ends the scan, and {@code isInit} is set either way.
     */
    public void init() {
        try {
            // Path of the jar the application runs from.
            String jarBaseFile = MemoryClassLoader.getPath();
            // NOTE(review): jarFile is never closed. The cached file objects hold URLs derived from
            // jarFile.getUrl() and presumably read them lazily later, so closing it here may break them - confirm.
            JarFile jarFile = new JarFile(new File(jarBaseFile));
            URL jarUrl = jarFile.getUrl();

            // Pass 1: class files packaged directly in the jar.
            for (JarEntry entry : jarFile) {
                String entryName = entry.getName();
                // Spring Boot repackaging puts application classes under BOOT-INF/classes/ before the package path.
                String relativeName = entryName.startsWith(BOOT_INF_CLASSES)
                        ? entryName.substring(BOOT_INF_CLASSES.length())
                        : entryName;
                cacheClassEntry(relativeName, jarUrl, entryName);
            }

            // Pass 2: class files inside the nested jars.
            List<JarEntry> nestedJarEntries = jarFile.stream()
                    .filter(jarEntry -> jarEntry.getName().endsWith(".jar"))
                    .collect(Collectors.toList());
            for (JarEntry entry : nestedJarEntries) {
                // Skip tools.jar.
                if (entry.getName().endsWith("tools.jar")) {
                    continue;
                }
                JarFile nestedJarFile = jarFile.getNestedJarFile(jarFile.getEntry(entry.getName()));
                if (nestedJarFile.getName().contains("tools.jar")) {
                    continue;
                }
                URL nestedUrl = nestedJarFile.getUrl();
                Enumeration<JarEntry> nestedEntries = nestedJarFile.entries();
                while (nestedEntries.hasMoreElements()) {
                    String nestedEntryName = nestedEntries.nextElement().getName();
                    cacheClassEntry(nestedEntryName, nestedUrl, nestedEntryName);
                }
            }
        } catch (Exception e) {
            // Best-effort scan: report the failure and fall back to whatever was cached so far.
            e.printStackTrace();
        }
        isInit = true;
    }

    /**
     * Registers one jar entry if it is a class file inside a named package. Other entries are ignored.
     *
     * <p>Root-level class files (e.g. {@code module-info.class}) are skipped. They have no package, and
     * deriving one would index with {@code lastIndexOf('/') == -1}.
     *
     * @param relativeName entry path relative to the class-path root, e.g. {@code com/foo/Bar.class}
     * @param url          base URL of the jar that contains the entry
     * @param entryName    full entry name inside that jar, resolved against {@code url}
     * @throws MalformedURLException if the entry URL cannot be built
     */
    private void cacheClassEntry(String relativeName, URL url, String entryName) throws MalformedURLException {
        int lastSlash = relativeName.lastIndexOf('/');
        if (!relativeName.endsWith(CLASS_SUFFIX) || lastSlash == -1) {
            return;
        }
        String packageName = relativeName.substring(0, lastSlash).replace('/', '.');
        // Strip only the trailing ".class". String.replace would also mangle packages such as "classreading".
        String className = relativeName
                .substring(0, relativeName.length() - CLASS_SUFFIX.length())
                .replace('/', '.');
        filterClass(packageName, className, url, entryName);
    }

    /**
     * Creates the file object for one class entry and appends it to the package's cached list.
     *
     * @param packageName dotted package name used as the cache key
     * @param className   dotted class name handed to the file object
     * @param url         base URL of the containing jar
     * @param entryName   entry name resolved against {@code url}
     * @throws MalformedURLException if the entry URL cannot be built
     */
    private void filterClass(String packageName, String className, URL url, String entryName) throws MalformedURLException {
        JavaFileObject fileObject =
                new MemorySpringBootInfoJavaClassObject(className, new URL(url, entryName), javaFileManager);
        CLASS_OBJECT_PACKAGE_MAP.computeIfAbsent(packageName, key -> new ArrayList<>()).add(fileObject);
    }

    /**
     * Creates the manager on top of the file manager returned by
     * {@link #getStandardFileManager(DiagnosticListener, Locale, Charset)}.
     */
    MemoryJavaFileManager() {
        super(getStandardFileManager(null, null, null));
        this.javaFileManager = (JavacFileManager) fileManager;
    }

    /**
     * Returns the cached jar class objects of a package, running the jar scan on first use.
     *
     * @param packageName dotted package name
     * @return the cached file objects, or {@code null} if none were found for the package
     */
    public List<JavaFileObject> getLibJarsOptions(String packageName) {
        synchronized (lock) {
            if (!isInit) {
                init();
            }
        }
        return CLASS_OBJECT_PACKAGE_MAP.get(packageName);
    }

    /**
     * Lists the file objects of a package.
     *
     * <p>For {@code CLASS_PATH} while running from a jar, the cached jar entries are returned when the
     * package has any. Otherwise the delegate's listing is used. When {@code CLASS} files are requested,
     * it is prefixed with the classes compiled in memory for that package.
     */
    @Override
    public Iterable<JavaFileObject> list(Location location,
                                         String packageName,
                                         Set<JavaFileObject.Kind> kinds,
                                         boolean recurse)
            throws IOException {
        if ("CLASS_PATH".equals(location.getName()) && MemoryClassLoader.isJar()) {
            List<JavaFileObject> cached = getLibJarsOptions(packageName);
            if (cached != null) {
                return cached;
            }
        }

        Iterable<JavaFileObject> delegated = super.list(location, packageName, kinds, recurse);
        if (!kinds.contains(JavaFileObject.Kind.CLASS)) {
            return delegated;
        }
        List<JavaFileObject> compiled = classObjectPackageMap.get(packageName);
        if (compiled == null) {
            return delegated;
        }
        // Build a fresh list. Appending into the cached list would re-add the delegate's results on every call.
        List<JavaFileObject> merged = new ArrayList<>(compiled);
        if (delegated != null) {
            for (JavaFileObject javaFileObject : delegated) {
                merged.add(javaFileObject);
            }
        }
        return merged;
    }

    /**
     * Returns the stored class name for in-memory class objects; otherwise defers to the delegate.
     */
    @Override
    public String inferBinaryName(Location location, JavaFileObject file) {
        if (file instanceof MemoryInputJavaClassObject) {
            return ((MemoryInputJavaClassObject) file).inferBinaryName();
        }
        return super.inferBinaryName(location, file);
    }

    /**
     * Routes {@code CLASS} output to an in-memory object; other kinds go to the delegate.
     */
    @Override
    public JavaFileObject getJavaFileForOutput(Location location, String className, JavaFileObject.Kind kind,
                                               FileObject sibling) throws IOException {
        if (kind == JavaFileObject.Kind.CLASS) {
            return new MemoryOutputJavaClassObject(className);
        }
        return super.getJavaFileForOutput(location, className, kind, sibling);
    }

    /**
     * Wraps source code in a {@code string:///} source file object.
     *
     * @param className dotted class name, used to build the URI path
     * @param code      the Java source text
     * @return a SOURCE file object whose content is {@code code}
     */
    JavaFileObject makeStringSource(String className, final String code) {
        String classPath = className.replace('.', '/') + JavaFileObject.Kind.SOURCE.extension;

        return new SimpleJavaFileObject(URI.create("string:///" + classPath), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharBuffer getCharContent(boolean ignoreEncodingErrors) {
                return CharBuffer.wrap(code);
            }
        };
    }

    /**
     * Registers compiled bytecode as an input class object under its package.
     *
     * @param className dotted class name
     * @param bs        the class bytes
     */
    void makeBinaryClass(String className, final byte[] bs) {
        JavaFileObject javaFileObject = new MemoryInputJavaClassObject(className, bs);
        int pos = className.lastIndexOf('.');
        // Names without a '.' past index 0 are filed under the unnamed package "".
        String packageName = pos > 0 ? className.substring(0, pos) : "";
        classObjectPackageMap.computeIfAbsent(packageName, key -> new ArrayList<>()).add(javaFileObject);
    }

    /**
     * In-memory CLASS input whose content is a fixed byte array.
     */
    class MemoryInputJavaClassObject extends SimpleJavaFileObject {
        final String className;
        final byte[] bs;

        MemoryInputJavaClassObject(String className, byte[] bs) {
            super(URI.create("string:///" + className.replace('.', '/') + Kind.CLASS.extension), Kind.CLASS);
            this.className = className;
            this.bs = bs;
        }

        @Override
        public InputStream openInputStream() {
            return new ByteArrayInputStream(bs);
        }

        /**
         * @return the class name this object was created with
         */
        public String inferBinaryName() {
            return className;
        }
    }

    /**
     * In-memory CLASS output. When its stream is closed, the bytes are stored in {@code classBytesMap}
     * and also registered through {@link #makeBinaryClass(String, byte[])}.
     */
    class MemoryOutputJavaClassObject extends SimpleJavaFileObject {
        final String className;

        MemoryOutputJavaClassObject(String className) {
            super(URI.create("string:///" + className.replace('.', '/') + Kind.CLASS.extension), Kind.CLASS);
            this.className = className;
        }

        @Override
        public OutputStream openOutputStream() {
            return new FilterOutputStream(new ByteArrayOutputStream()) {
                @Override
                public void close() throws IOException {
                    out.close();
                    ByteArrayOutputStream bos = (ByteArrayOutputStream) out;
                    byte[] bs = bos.toByteArray();
                    classBytesMap.put(className, bs);
                    makeBinaryClass(className, bs);
                }
            };
        }
    }

    /**
     * Returns the compilation output.
     *
     * @return a copy of the class-name to bytecode map
     */
    public Map<String, byte[]> getClassBytes() {
        return new HashMap<>(this.classBytesMap);
    }

    /**
     * No-op: output is held in memory and the delegate is not flushed.
     *
     * @throws IOException never thrown here
     */
    @Override
    public void flush() throws IOException {
    }

    /**
     * Clears the stored compilation output. The delegate file manager is not closed here.
     *
     * @throws IOException never thrown here
     */
    @Override
    public void close() throws IOException {
        classBytesMap.clear();
    }

    /**
     * Builds the Spring-aware javac file manager used as this manager's delegate.
     *
     * @param diagnosticListener listener registered in the javac context when non-null
     * @param locale             locale stored in the javac context
     * @param charset            charset for the {@code System.err} log writer and the file manager;
     *                           {@code null} uses the platform default writer
     * @return a new {@link SpringJavaFileManager}
     */
    public static SpringJavaFileManager getStandardFileManager(DiagnosticListener<? super JavaFileObject> diagnosticListener,
                                                               Locale locale, Charset charset) {
        Context context = new Context();
        context.put(Locale.class, locale);
        if (diagnosticListener != null) {
            context.put(DiagnosticListener.class, diagnosticListener);
        }
        PrintWriter logWriter = charset == null
                ? new PrintWriter(System.err, true)
                : new PrintWriter(new OutputStreamWriter(System.err, charset), true);
        context.put(Log.outKey, logWriter);
        return new SpringJavaFileManager(context, true, charset);
    }
}
